{
	"cells": [
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"# Super charge your LLMs with RAG at scale using AWS Glue for Apache Spark"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"This sample notebook demonstrates how to build RAG using AWS Glue for Apache Spark. In this sample, we leverage LangChain integrating with AWS Glue for Apache Spark, Amazon SageMaker JumpStart, and Amazon OpenSearch Serverless. To make this solution scalable and customizable, we leverage Apache Spark’s distributed fashion and PySpark’s flexible scripting capabilities. We use OpenSearch Serverless as a sample vector store, and use Llama 3.1 model provided in SageMaker JumpStart."
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"%session_id_prefix langchain-spark-rag-jumpstart-\n",
				"%glue_version 4.0\n",
				"%idle_timeout 2880\n",
				"%worker_type G.1X\n",
				"%number_of_workers 5"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"%additional_python_modules langchain==0.2.12, langchain-community==0.2.11, langchain-aws==0.1.16, typing-extensions, typing-inspect, beautifulsoup4, lxml, markdownify, pydantic, boto3, botocore, opensearch-py, sagemaker"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"## Vectorstore setup"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"import json\n",
				"import os\n",
				"os.environ[\"AWS_DEFAULT_REGION\"] = \"<region>\"\n",
				"iam_role_arn = \"<your-iam-role-arn>\""
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"import boto3\n",
				"import time\n",
				"vector_store_name = 'langchain-spark-rag'\n",
				"index_name = \"langchain-spark-rag-index\"\n",
				"encryption_policy_name = \"langchain-spark-rag-sp\"\n",
				"network_policy_name = \"langchain-spark-rag-np\"\n",
				"access_policy_name = 'langchain-spark-rag-ap'\n",
				"identity = iam_role_arn\n",
				"\n",
				"aoss_client = boto3.client('opensearchserverless')\n",
				"\n",
				"security_policy = aoss_client.create_security_policy(\n",
				"    name = encryption_policy_name,\n",
				"    policy = json.dumps(\n",
				"        {\n",
				"            'Rules': [{'Resource': ['collection/' + vector_store_name],\n",
				"            'ResourceType': 'collection'}],\n",
				"            'AWSOwnedKey': True\n",
				"        }),\n",
				"    type = 'encryption'\n",
				")\n",
				"\n",
				"network_policy = aoss_client.create_security_policy(\n",
				"    name = network_policy_name,\n",
				"    policy = json.dumps(\n",
				"        [\n",
				"            {'Rules': [{'Resource': ['collection/' + vector_store_name],\n",
				"            'ResourceType': 'collection'}],\n",
				"            'AllowFromPublic': True}\n",
				"        ]),\n",
				"    type = 'network'\n",
				")\n",
				"\n",
				"collection = aoss_client.create_collection(name=vector_store_name,type='VECTORSEARCH')\n",
				"\n",
				"while True:\n",
				"    status = aoss_client.list_collections(collectionFilters={'name':vector_store_name})['collectionSummaries'][0]['status']\n",
				"    if status in ('ACTIVE', 'FAILED'): break\n",
				"    time.sleep(10)\n",
				"\n",
				"access_policy = aoss_client.create_access_policy(\n",
				"    name = access_policy_name,\n",
				"    policy = json.dumps(\n",
				"        [\n",
				"            {\n",
				"                'Rules': [\n",
				"                    {\n",
				"                        'Resource': ['collection/' + vector_store_name],\n",
				"                        'Permission': [\n",
				"                            'aoss:CreateCollectionItems',\n",
				"                            'aoss:DeleteCollectionItems',\n",
				"                            'aoss:UpdateCollectionItems',\n",
				"                            'aoss:DescribeCollectionItems'],\n",
				"                        'ResourceType': 'collection'\n",
				"                    },\n",
				"                    {\n",
				"                        'Resource': ['index/' + vector_store_name + '/*'],\n",
				"                        'Permission': [\n",
				"                            'aoss:CreateIndex',\n",
				"                            'aoss:DeleteIndex',\n",
				"                            'aoss:UpdateIndex',\n",
				"                            'aoss:DescribeIndex',\n",
				"                            'aoss:ReadDocument',\n",
				"                            'aoss:WriteDocument'],\n",
				"                        'ResourceType': 'index'\n",
				"                    }],\n",
				"                'Principal': [identity],\n",
				"                'Description': 'Easy data policy'}\n",
				"        ]),\n",
				"    type = 'data'\n",
				")\n",
				"\n",
				"host = \"https://\" + collection['createCollectionDetail']['id'] + '.' + os.environ.get(\"AWS_DEFAULT_REGION\", None) + '.aoss.amazonaws.com'\n"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"The above cell takes a few minutes (~3 minutes) to complete."
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"## Sample document download"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"import requests\n",
				"\n",
				"content = requests.get(\"https://lilianweng.github.io/posts/2023-06-23-agent/\")\n",
				"\n",
				"s3_resource = boto3.resource(\"s3\")\n",
				"\n",
				"account_id = boto3.client('sts').get_caller_identity()['Account']\n",
				"region = os.environ.get(\"AWS_DEFAULT_REGION\", None)\n",
				"bucket_name = f'langchain-spark-rag-jumpstart-{account_id}-{region}'\n",
				"\n",
				"bucket = s3_resource.Bucket(bucket_name)\n",
				"# bucket.create()\n",
				"if region == 'us-east-1':\n",
				"    bucket.create()\n",
				"else:\n",
				"    bucket.create(CreateBucketConfiguration={'LocationConstraint': region})\n",
				"\n",
				"key = \"data/langchain-spark-rag/content.html\"\n",
				"obj = s3_resource.Object(bucket_name, key)\n",
				"obj.put(Body=content.text)"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"## Document preparation"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"#### Read HTML file"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"df_html = spark.read.text(f\"s3://{bucket_name}/{key}\", wholetext=True)\n",
				"df_html.show()"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"#### Parse and clean up HTML"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"from bs4 import BeautifulSoup\n",
				"from markdownify import MarkdownConverter\n",
				"\n",
				"def parse_html(html):\n",
				"    soup = BeautifulSoup(html, \"lxml\")\n",
				"    soup.smooth()\n",
				"    md = MarkdownConverter(heading_style=\"ATX\").convert_soup(soup)\n",
				"    return format_md(md)\n"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"import re\n",
				"\n",
				"def format_md(md):\n",
				"    # Remove extra preceding white spaces in a line\n",
				"    md = re.sub(r\"^ +\", \"\", md, flags=re.MULTILINE)\n",
				"    # Replace all newlines with spaces\n",
				"    md = re.sub(r\"\\n\", \" \", md)\n",
				"    # Remove redundant spaces\n",
				"    md = re.sub(r\"\\s\\s+\", \" \", md)\n",
				"    # Convert \"\\_\" to \"_\"\n",
				"    md = re.sub(r\"\\\\+_\", \"_\", md)\n",
				"\n",
				"    return md"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"#### Chunking HTML"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"from langchain.text_splitter import MarkdownTextSplitter\n",
				"\n",
				"def chunk_md(md):\n",
				"    splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=100)\n",
				"    return splitter.split_text(md)"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"#### Embedding"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"_MODEL_CONFIG_ = {\n",
				"    \"huggingface-sentencesimilarity-all-MiniLM-L6-v2\": {\n",
				"        \"instance type\": \"ml.g5.2xlarge\",\n",
				"        \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
				"    },\n",
				"    \"meta-textgeneration-llama-3-1-8b-instruct\": {\n",
				"        \"instance type\": \"ml.g5.2xlarge\",\n",
				"        \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
				"    }\n",
				"}"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"from sagemaker.jumpstart.model import JumpStartModel\n",
				"from sagemaker.utils import name_from_base\n",
				"\n",
				"newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n",
				"\n",
				"for model_id in _MODEL_CONFIG_:\n",
				"    endpoint_name = name_from_base(f\"jumpstart-example-raglc-{model_id}\")\n",
				"    inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n",
				"\n",
				"    try:\n",
				"        # Create and deploy the JumpStart model\n",
				"        model = JumpStartModel(\n",
				"            model_id=model_id,\n",
				"            role=iam_role_arn,\n",
				"            instance_type=inference_instance_type,\n",
				"            predictor_cls=None,  # Use default predictor\n",
				"            env=_MODEL_CONFIG_[model_id][\"env\"]\n",
				"        )\n",
				"\n",
				"        predictor = model.deploy(\n",
				"            initial_instance_count=1,\n",
				"            instance_type=inference_instance_type,\n",
				"            endpoint_name=endpoint_name,\n",
				"            accept_eula=True\n",
				"        )\n",
				"\n",
				"        print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n",
				"        _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name\n",
				"\n",
				"    except Exception as e:\n",
				"        print(f\"Error deploying model {model_id}: {str(e)}\")\n",
				"        continue"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"from langchain_community.embeddings import SagemakerEndpointEmbeddings\n",
				"from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
				"from langchain_community.vectorstores import OpenSearchVectorSearch\n",
				"from opensearchpy import RequestsHttpConnection, AWSV4SignerAuth\n",
				"\n",
				"service = \"aoss\"\n",
				"region = os.environ.get(\"AWS_DEFAULT_REGION\", None)\n",
				"\n",
				"def process_batch(chunks):\n",
				"    from typing import List, Dict, Any\n",
				"    credentials = boto3.Session().get_credentials()\n",
				"    awsauth = AWSV4SignerAuth(credentials, region, service)\n",
				"\n",
				"    class ContentHandler(EmbeddingsContentHandler):\n",
				"        content_type = \"application/json\"\n",
				"        accepts = \"application/json\"\n",
				"\n",
				"        def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:\n",
				"            input_str = json.dumps({\"text_inputs\": prompts, \"mode\": \"embedding\", **model_kwargs})\n",
				"            return input_str.encode('utf-8')\n",
				"\n",
				"        def transform_output(self, output: bytes) -> List[List[float]]:\n",
				"            response_json = json.loads(output.read().decode(\"utf-8\"))\n",
				"            return response_json[\"embedding\"]\n",
				"\n",
				"    embeddings = SagemakerEndpointEmbeddings(\n",
				"        endpoint_name=_MODEL_CONFIG_[\"huggingface-sentencesimilarity-all-MiniLM-L6-v2\"][\"endpoint_name\"],\n",
				"        region_name=region,\n",
				"        content_handler=ContentHandler(),\n",
				"    )\n",
				"\n",
				"    vectorstore = OpenSearchVectorSearch.from_texts(\n",
				"        chunks, \n",
				"        embeddings,\n",
				"        index_name=index_name, \n",
				"        opensearch_url=host, \n",
				"        http_auth=awsauth,\n",
				"        timeout=100,\n",
				"        use_ssl=True,\n",
				"        verify_certs=True,\n",
				"        connection_class=RequestsHttpConnection\n",
				"    )\n",
				"    ret = vectorstore.add_texts(texts=chunks)\n",
				"    return {\"num_success\": len(ret)}"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"#### Pre-process HTML document"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"from pyspark.sql.functions import col, udf\n",
				"\n",
				"def process_html(html):\n",
				"    cleaned_md = parse_html(html)\n",
				"    chunks = chunk_md(cleaned_md)\n",
				"    process_batch(chunks)\n",
				"    return cleaned_md\n",
				"\n",
				"process_html_udf = udf(lambda z: process_html(z))"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"df_html_processed = df_html.select(process_html_udf(col(\"value\")).alias(\"text\"))\n",
				"df_html_processed.show()"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"## Question answering"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"import logging\n",
				"from typing import List, Dict, Any\n",
				"from langchain_community.llms import SagemakerEndpoint\n",
				"from langchain_community.llms.sagemaker_endpoint import LLMContentHandler\n",
				"\n",
				"credentials = boto3.Session().get_credentials()\n",
				"awsauth = AWSV4SignerAuth(credentials, region, service)\n",
				"\n",
				"sagemaker_client = boto3.client('sagemaker-runtime')\n",
				"\n",
				"\n",
				"class EmbeddingsContentHandler(EmbeddingsContentHandler):\n",
				"    content_type = \"application/json\"\n",
				"    accepts = \"application/json\"\n",
				"\n",
				"    def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:\n",
				"        input_str = json.dumps({\"text_inputs\": prompts, \"mode\": \"embedding\", **model_kwargs})\n",
				"        return input_str.encode('utf-8')\n",
				"\n",
				"    def transform_output(self, output: bytes) -> List[List[float]]:\n",
				"        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
				"        return response_json[\"embedding\"]\n",
				"\n",
				"class LLMContentHandler(LLMContentHandler):\n",
				"    content_type = \"application/json\"\n",
				"    accepts = \"application/json\"\n",
				"\n",
				"    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
				"        input_str = json.dumps({\n",
				"            \"inputs\": prompt,\n",
				"            \"parameters\": {\n",
				"                \"max_new_tokens\": 256,\n",
				"                \"top_p\": 0.9,\n",
				"                \"temperature\": model_kwargs.get(\"temperature\", 0.6),\n",
				"            }\n",
				"        })\n",
				"        return input_str.encode('utf-8')\n",
				"\n",
				"    def transform_output(self, output: bytes) -> str:\n",
				"        try:\n",
				"            response_json = json.loads(output.read().decode(\"utf-8\"))\n",
				"            logging.info(f\"Raw output: {response_json}\")\n",
				"            \n",
				"            if isinstance(response_json, list):\n",
				"                if response_json and isinstance(response_json[0], dict):\n",
				"                    return response_json[0].get(\"generated_text\", str(response_json[0]))\n",
				"                else:\n",
				"                    return str(response_json[0]) if response_json else \"\"\n",
				"            elif isinstance(response_json, dict):\n",
				"                return response_json.get(\"generated_text\", str(response_json))\n",
				"            else:\n",
				"                return str(response_json)\n",
				"        except Exception as e:\n",
				"            logging.error(f\"Error in transform_output: {str(e)}\")\n",
				"            logging.error(f\"Raw output: {output}\")\n",
				"            raise\n",
				"\n",
				"embeddings = SagemakerEndpointEmbeddings(\n",
				"    endpoint_name=_MODEL_CONFIG_[\"huggingface-sentencesimilarity-all-MiniLM-L6-v2\"][\"endpoint_name\"],\n",
				"    region_name=region,\n",
				"    content_handler=EmbeddingsContentHandler(),\n",
				")\n",
				"\n",
				"opensearch_vector_search_client = OpenSearchVectorSearch(\n",
				"    index_name=index_name,\n",
				"    embedding_function=embeddings,\n",
				"    opensearch_url=host,\n",
				"    http_auth=awsauth,\n",
				"    timeout=100,\n",
				"    use_ssl=True,\n",
				"    verify_certs=True,\n",
				"    connection_class=RequestsHttpConnection,\n",
				"    vector_field=\"vector_field\",\n",
				"    engine=\"faiss\"\n",
				")\n",
				"\n",
				"llm = SagemakerEndpoint(\n",
				"    endpoint_name=_MODEL_CONFIG_[\"meta-textgeneration-llama-3-1-8b-instruct\"][\"endpoint_name\"],\n",
				"    client=sagemaker_client,\n",
				"    model_kwargs={\"temperature\": 1e-10},\n",
				"    content_handler=LLMContentHandler(),\n",
				")"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"from langchain.chains import RetrievalQA\n",
				"qa_aoss = RetrievalQA.from_chain_type(\n",
				"    llm=llm,\n",
				"    chain_type=\"stuff\",\n",
				"    retriever=opensearch_vector_search_client.as_retriever(),\n",
				"    return_source_documents=True,\n",
				")"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"from langchain.load.dump import dumps\n",
				"query = \"What is Task Decomposition?\"\n",
				"results = opensearch_vector_search_client.similarity_search(query)\n",
				"print(dumps(results, pretty=True))"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"qa_aoss.invoke(query)"
			]
		},
		{
			"cell_type": "markdown",
			"metadata": {},
			"source": [
				"## Clean up"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"tags": [],
				"trusted": true,
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": [
				"## S3\n",
				"bucket.object_versions.delete()\n",
				"bucket.delete()\n",
				"\n",
				"## OpenSearch Serverless\n",
				"aoss_client.delete_collection(id=collection['createCollectionDetail']['id'])\n",
				"aoss_client.delete_access_policy(name=access_policy_name, type='data')\n",
				"aoss_client.delete_security_policy(name=encryption_policy_name, type='encryption')\n",
				"aoss_client.delete_security_policy(name=network_policy_name, type='network')\n",
				"\n",
				"## SageMaker\n",
				"sagemaker_client = boto3.client('sagemaker')\n",
				"for model_id in _MODEL_CONFIG_:\n",
				"    sagemaker_client.delete_endpoint(EndpointName=_MODEL_CONFIG_[model_id][\"endpoint_name\"])"
			]
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": []
		},
		{
			"cell_type": "code",
			"execution_count": null,
			"metadata": {
				"vscode": {
					"languageId": "python_glue_session"
				}
			},
			"outputs": [],
			"source": []
		}
	],
	"metadata": {
		"kernelspec": {
			"display_name": "Glue PySpark",
			"language": "python",
			"name": "glue_pyspark"
		},
		"language_info": {
			"codemirror_mode": {
				"name": "python",
				"version": 3
			},
			"file_extension": ".py",
			"mimetype": "text/x-python",
			"name": "Python_Glue_Session",
			"pygments_lexer": "python3"
		}
	},
	"nbformat": 4,
	"nbformat_minor": 4
}
